# -*- coding: utf-8 -*-
from keras.callbacks import Callback
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score


class DistanceMetrics(Callback):
    """Keras callback: per-epoch precision/recall/F1/AUC on the validation set.

    Intended for a two-input distance (siamese-style) model: a pair is
    predicted positive when the model's output distance falls *below*
    ``threshold``. Metric histories are kept in ``val_f1s``,
    ``val_recalls``, ``val_precisions`` and ``val_aucs``, and each epoch's
    values are also written into ``logs`` so other callbacks can see them.

    NOTE(review): relies on ``self.validation_data`` being populated as
    ``(input_a, input_b, y_true, ...)`` — only true on Keras versions that
    still expose ``validation_data`` on callbacks; confirm for your version.
    """

    def __init__(self, threshold=0.65):
        """``threshold``: distance cut-off; distances below it count as positive."""
        super(DistanceMetrics, self).__init__()
        self.threshold = threshold

    def on_train_begin(self, logs=None):
        # Reset metric histories at the start of every fit() call.
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []
        self.val_aucs = []

    def on_epoch_end(self, epoch, logs=None):
        # Never use a mutable default for `logs`: we mutate it below, and a
        # shared `{}` default would leak state across calls/instances.
        if logs is None:
            logs = {}
        valid_results = self.model.predict(
            [self.validation_data[0], self.validation_data[1]])
        # Distance below the threshold => predicted positive (same pair).
        valid_y_pred = [result[0] < self.threshold for result in valid_results]
        valid_y_true = self.validation_data[2]
        _val_f1 = f1_score(valid_y_true, valid_y_pred)
        _val_recall = recall_score(valid_y_true, valid_y_pred)
        _val_precision = precision_score(valid_y_true, valid_y_pred)
        # NOTE(review): AUC over *thresholded* labels, not raw scores — this
        # degenerates to a single operating point; consider feeding the raw
        # distances (negated) to roc_auc_score instead.
        _val_auc = roc_auc_score(valid_y_true, valid_y_pred)
        # NOTE(review): 'val_precisions' (plural) is inconsistent with the
        # other keys; kept as-is in case downstream monitors depend on it.
        logs['val_precisions'] = _val_precision
        logs['val_recall'] = _val_recall
        logs['val_f1'] = _val_f1
        logs['val_auc'] = _val_auc
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        self.val_aucs.append(_val_auc)
        print('- val_precision: %.4f - val_recall: %.4f  - val_f1: %.4f - val_auc: %.4f' %
              (_val_precision, _val_recall, _val_f1, _val_auc))
        return


class CategoricalMetrics(Callback):
    """Keras callback: per-epoch precision/recall/F1/AUC on the validation set.

    Intended for a two-input model emitting a probability-like score: a
    sample is predicted positive when the score is *above* ``threshold``.
    Metric histories accumulate in ``val_f1s``, ``val_recalls``,
    ``val_precisions`` and ``val_aucs``; each epoch's values are also
    written into ``logs`` for other callbacks.

    NOTE(review): relies on ``self.validation_data`` being populated as
    ``(input_a, input_b, y_true, ...)`` — only true on Keras versions that
    still expose ``validation_data`` on callbacks; confirm for your version.
    """

    def __init__(self, threshold=0.4):
        """``threshold``: score cut-off; scores above it count as positive."""
        super(CategoricalMetrics, self).__init__()
        self.threshold = threshold

    def on_train_begin(self, logs=None):
        # Reset metric histories at the start of every fit() call.
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []
        self.val_aucs = []

    def on_epoch_end(self, epoch, logs=None):
        # Never use a mutable default for `logs`: we mutate it below, and a
        # shared `{}` default would leak state across calls/instances.
        if logs is None:
            logs = {}
        valid_results = self.model.predict(
            [self.validation_data[0], self.validation_data[1]])
        # Score above the threshold => predicted positive.
        valid_y_pred = [result[0] > self.threshold for result in valid_results]
        valid_y_true = self.validation_data[2]
        _val_f1 = f1_score(valid_y_true, valid_y_pred)
        _val_recall = recall_score(valid_y_true, valid_y_pred)
        _val_precision = precision_score(valid_y_true, valid_y_pred)
        # NOTE(review): AUC over *thresholded* labels, not raw scores — this
        # degenerates to a single operating point; consider feeding the raw
        # scores to roc_auc_score instead.
        _val_auc = roc_auc_score(valid_y_true, valid_y_pred)
        # NOTE(review): 'val_precisions' (plural) is inconsistent with the
        # other keys; kept as-is in case downstream monitors depend on it.
        logs['val_precisions'] = _val_precision
        logs['val_recall'] = _val_recall
        logs['val_f1'] = _val_f1
        logs['val_auc'] = _val_auc
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        self.val_aucs.append(_val_auc)
        print('- val_precision: %.4f - val_recall: %.4f  - val_f1: %.4f - val_auc: %.4f' %
              (_val_precision, _val_recall, _val_f1, _val_auc))
        return


class DistanceMixedMetrics(Callback):
    """Keras callback: per-epoch precision/recall/F1/AUC for a four-input model.

    Distance variant: a pair is predicted positive when the model's output
    distance falls *below* ``threshold``. Metric histories accumulate in
    ``val_f1s``, ``val_recalls``, ``val_precisions`` and ``val_aucs``; each
    epoch's values are also written into ``logs`` for other callbacks.

    NOTE(review): relies on ``self.validation_data`` being populated as
    ``(in0, in1, in2, in3, y_true, ...)`` — only true on Keras versions that
    still expose ``validation_data`` on callbacks; confirm for your version.
    """

    def __init__(self, threshold=0.65):
        """``threshold``: distance cut-off; distances below it count as positive."""
        super(DistanceMixedMetrics, self).__init__()
        self.threshold = threshold

    def on_train_begin(self, logs=None):
        # Reset metric histories at the start of every fit() call.
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []
        self.val_aucs = []

    def on_epoch_end(self, epoch, logs=None):
        # Never use a mutable default for `logs`: we mutate it below, and a
        # shared `{}` default would leak state across calls/instances.
        if logs is None:
            logs = {}
        valid_results = self.model.predict(
            [self.validation_data[0], self.validation_data[1],
             self.validation_data[2], self.validation_data[3]])
        # Distance below the threshold => predicted positive (same pair).
        valid_y_pred = [result[0] < self.threshold for result in valid_results]
        valid_y_true = self.validation_data[4]
        _val_f1 = f1_score(valid_y_true, valid_y_pred)
        _val_recall = recall_score(valid_y_true, valid_y_pred)
        _val_precision = precision_score(valid_y_true, valid_y_pred)
        # NOTE(review): AUC over *thresholded* labels, not raw scores — this
        # degenerates to a single operating point; consider feeding the raw
        # distances (negated) to roc_auc_score instead.
        _val_auc = roc_auc_score(valid_y_true, valid_y_pred)
        # NOTE(review): 'val_precisions' (plural) is inconsistent with the
        # other keys; kept as-is in case downstream monitors depend on it.
        logs['val_precisions'] = _val_precision
        logs['val_recall'] = _val_recall
        logs['val_f1'] = _val_f1
        logs['val_auc'] = _val_auc
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        self.val_aucs.append(_val_auc)
        print('- val_precision: %.4f - val_recall: %.4f  - val_f1: %.4f - val_auc: %.4f' %
              (_val_precision, _val_recall, _val_f1, _val_auc))
        return


class CategoricalMixedMetrics(Callback):
    """Keras callback: per-epoch precision/recall/F1/AUC for a four-input model.

    Categorical variant: a sample is predicted positive when the model's
    probability-like score is *above* ``threshold``. Metric histories
    accumulate in ``val_f1s``, ``val_recalls``, ``val_precisions`` and
    ``val_aucs``; each epoch's values are also written into ``logs`` for
    other callbacks.

    NOTE(review): relies on ``self.validation_data`` being populated as
    ``(in0, in1, in2, in3, y_true, ...)`` — only true on Keras versions that
    still expose ``validation_data`` on callbacks; confirm for your version.
    """

    def __init__(self, threshold=0.4):
        """``threshold``: score cut-off; scores above it count as positive."""
        super(CategoricalMixedMetrics, self).__init__()
        self.threshold = threshold

    def on_train_begin(self, logs=None):
        # Reset metric histories at the start of every fit() call.
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []
        self.val_aucs = []

    def on_epoch_end(self, epoch, logs=None):
        # Never use a mutable default for `logs`: we mutate it below, and a
        # shared `{}` default would leak state across calls/instances.
        if logs is None:
            logs = {}
        valid_results = self.model.predict(
            [self.validation_data[0], self.validation_data[1],
             self.validation_data[2], self.validation_data[3]])
        # Score above the threshold => predicted positive.
        valid_y_pred = [result[0] > self.threshold for result in valid_results]
        valid_y_true = self.validation_data[4]
        _val_f1 = f1_score(valid_y_true, valid_y_pred)
        _val_recall = recall_score(valid_y_true, valid_y_pred)
        _val_precision = precision_score(valid_y_true, valid_y_pred)
        # NOTE(review): AUC over *thresholded* labels, not raw scores — this
        # degenerates to a single operating point; consider feeding the raw
        # scores to roc_auc_score instead.
        _val_auc = roc_auc_score(valid_y_true, valid_y_pred)
        # NOTE(review): 'val_precisions' (plural) is inconsistent with the
        # other keys; kept as-is in case downstream monitors depend on it.
        logs['val_precisions'] = _val_precision
        logs['val_recall'] = _val_recall
        logs['val_f1'] = _val_f1
        logs['val_auc'] = _val_auc
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        self.val_aucs.append(_val_auc)
        print('- val_precision: %.4f - val_recall: %.4f  - val_f1: %.4f - val_auc: %.4f' %
              (_val_precision, _val_recall, _val_f1, _val_auc))
        return


# Ready-made singleton callbacks for the common case: pass one of these in a
# model's fit(..., callbacks=[...]) list. Note they hold per-training state
# (metric history lists), so sharing one instance across concurrent trainings
# would interleave histories — construct a fresh instance in that case.
distance_metrics = DistanceMetrics()
categorical_metrics = CategoricalMetrics()

distance_mixed_metrics = DistanceMixedMetrics()
categorical_mixed_metrics = CategoricalMixedMetrics()
